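# Component-wise training of label-level causal mechanisms: per-mechanism label
# generators trained against gradient-penalised (WGAN-GP style) critics.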
import copy
import json
import os

import numpy as np
import torch
import pickle

from Causal_Train_By_Components.Causal_TrainGraph import set_trainGraph
from Causal_Train_By_Components.GAN_Evaluation.EvaluateCausalGAN import evaluate_after_epochs, getdoKey
from Causal_Train_By_Components.mnistControllerModel import get_discriminators

from Causal_Train_By_Components.mnistControllerModel import get_generated_labels, get_generators
from CausalTwoDiscrimMechTrain.ConstantFunctions import save_checkpoint, asKey, draw_true_graph
from CausalTwoDiscrimMechTrain.ControllerConstants import get_multiple_labels_fill

from CausalTwoDiscrimMechTrain.Experiment_Class import Experiment
from CausalTwoDiscrimMechTrain.Lables_Gradient_Penalty import calc_gradient_penalty
# from DataGenerationMechanisms.mnist_generate_synthetic import get_synthetic_dist, get_bayesian_network
from CausalMNISTAddition.GroundTruth.Synthetic_Distribution_Mnist import get_bayesian_network, get_intv_dist, \
    get_synthetic_dist
from CausalTwoDiscrimMechTrain.PlotLabels import plot_lines


def get_dataset(Exp, label, dno):
    """Load one label's samples from pickle file(s) as a column tensor.

    For each feature suffix (currently only ``"feature"``) the file
    ``Exp.file_roots[dno] + label + suffix + ".pkl"`` is unpickled, turned into
    a float tensor and reshaped into a single column. The columns are
    concatenated side by side, moved to ``Exp.DEVICE`` and the resulting shape
    is printed.

    Args:
        Exp: experiment config; ``file_roots`` and ``DEVICE`` are read.
        label: variable name used as the file-name prefix.
        dno: index into ``Exp.file_roots`` selecting the dataset directory.

    Returns:
        Float tensor of shape ``(num_samples, number_of_suffixes)``.
    """
    # NOTE(review): pickle.load can execute arbitrary code - only load trusted files.
    root = Exp.file_roots[dno]
    columns = []
    for suffix in ("feature",):
        path = root + label + suffix + ".pkl"
        with open(path, 'rb') as handle:
            raw = pickle.load(handle)
        values = torch.FloatTensor(raw)
        columns.append(values.view(len(values), 1))

    stacked = torch.cat(columns, 1).to(Exp.DEVICE)
    print(stacked.shape)
    return stacked


def get_intv_dataset(Exp):
    """Load and stack every interventional dataset (indices ``1 .. num_datasets-1``).

    For each dataset index ``dno >= 1`` one pickle per label
    (``Exp.file_roots[dno] + label + ".pkl"``) is loaded as a float column. The
    columns are concatenated in ``Exp.label_names`` order and moved to
    ``Exp.DEVICE``. The per-dataset tensors are then stacked row-wise.

    Returns:
        A ``(total_rows, len(Exp.label_names))`` tensor, or ``None`` when
        ``Exp.num_datasets <= 1`` (no interventional datasets).
    """
    def _load_one(dno):
        # One column per label, in label_names order.
        columns = []
        for label in Exp.label_names:
            with open(Exp.file_roots[dno] + label + ".pkl", 'rb') as handle:
                values = torch.FloatTensor(pickle.load(handle))
            columns.append(values.view(len(values), 1))
        return torch.cat(columns, 1).to(Exp.DEVICE)

    per_dataset = [_load_one(dno) for dno in range(1, Exp.num_datasets)]
    if not per_dataset:
        return None
    return torch.cat(per_dataset, 0)



def _mech_variables(Exp, cur_mechs, interv_no):
    """Return ``(compare_vars, intervened_vars)`` for ``cur_mechs`` under intervention ``interv_no``.

    ``compare_vars`` merges each mechanism's ``"compare"`` list, skipping names an
    earlier mechanism already contributed. ``intervened_vars`` concatenates the keys
    of each mechanism's ``"intv"`` dict; duplicates are kept.
    """
    compare_vars = []
    intervened_vars = []
    for mech in cur_mechs:
        spec = Exp.train_mech_dict[mech][interv_no]
        compare_vars += [lb for lb in spec["compare"] if lb not in compare_vars]
        intervened_vars += spec["intv"].keys()
    return compare_vars, intervened_vars


def train_CausalController(Exp, cur_mechs, label_generators, G_optimizers, label_discriminator, D_optimizer,
                           dataset_dict_batches, batchno):
    """Run one adversarial update of the mechanisms in ``cur_mechs`` on batch ``batchno``.

    The keys of ``dataset_dict_batches`` are enumerated in dict order. The i-th key
    pairs with ``label_discriminator[i]``, ``D_optimizer[i]`` and
    ``Exp.train_mech_dict[mech][i]``. For each key:

    1. The real values of the compared variables are encoded with
       ``get_multiple_labels_fill``.
    2. The intervened variables are read from the real batch and handed to
       ``get_generated_labels`` as fixed inputs.
    3. The critic is trained ``Exp.CRITIC_ITERATIONS`` times with loss
       ``-(mean D(real) - mean D(fake)) + LAMBDA_GP * gp``.
    4. ``-mean D(fake)`` is added to the generator loss.

    Finally every generator in ``cur_mechs`` takes one optimizer step on the summed
    generator loss.

    Returns:
        ``(G_loss, D_loss)``:
        - ``G_loss`` is a 1-element tensor holding the summed generator loss.
        - ``D_loss`` is a CPU float tensor holding the mean critic loss over all
          interventions and critic iterations.

    Raises:
        ValueError: if ``dataset_dict_batches`` is empty.
    """
    if not dataset_dict_batches:
        raise ValueError("dataset_dict_batches is empty: no data to train on")

    G_loss = torch.zeros(1).to(Exp.DEVICE)
    # Critic losses from *every* intervention. Previously this list was reset per
    # intervention, so only the last one was reported.
    D_losses = []
    for interv_no, dataset_batches in enumerate(dataset_dict_batches.values()):
        data_input = dataset_batches[batchno]
        compare_Var, intervened_Var = _mech_variables(Exp, cur_mechs, interv_no)
        discriminator = label_discriminator[interv_no]
        d_optimizer = D_optimizer[interv_no]

        mini_batch = data_input.size()[0]
        # Column i of data_input is expected to hold variable Exp.label_names[i].
        indices = [Exp.label_names.index(lb) for lb in compare_Var]
        # .long() truncates like .type(torch.LongTensor) but stays on the same device.
        current_real_label = data_input[:, indices].long().view(-1, len(indices)).to(Exp.DEVICE)
        dims_list = [Exp.label_dim[lb]["feature"] for lb in compare_Var]
        real_labels_fill = get_multiple_labels_fill(Exp, current_real_label, dims_list, isImage_labels=False)

        # Intervened variables are fixed to their real values (empty when no intervention).
        intv_tensor_dict = {}
        for intv_lb in intervened_Var:
            col = Exp.label_names.index(intv_lb)
            parent_intv_label = data_input[:, [col]].long().view(-1, 1).to(Exp.DEVICE)
            intv_tensor_dict[intv_lb] = get_multiple_labels_fill(
                Exp, parent_intv_label, [Exp.label_dim[intv_lb]["feature"]], isImage_labels=False)

        generated_labels_dict = get_generated_labels(Exp, label_generators, {}, {}, intv_tensor_dict, compare_Var,
                                                     mini_batch)
        y_dims = sum(dims_list)
        generated_labels_fill = torch.cat(list(generated_labels_dict.values()), 1).view(-1, y_dims)

        for _ in range(Exp.CRITIC_ITERATIONS):
            D_real_decision_obs = discriminator(real_labels_fill).squeeze()
            D_fake_decision_obs = discriminator(generated_labels_fill).squeeze()

            # NOTE(review): LAMBDA_GP is passed to calc_gradient_penalty AND multiplied
            # below; if the helper already scales by it the penalty is applied twice - confirm.
            gp_obs = calc_gradient_penalty(discriminator, real_labels_fill, generated_labels_fill, Exp.LAMBDA_GP,
                                           device=Exp.DEVICE)
            D_loss_obs = -(torch.mean(D_real_decision_obs) - torch.mean(D_fake_decision_obs)) \
                + Exp.LAMBDA_GP * gp_obs
            D_losses.append(D_loss_obs.data)

            discriminator.zero_grad()
            # retain_graph: the graph behind generated_labels_fill is reused by later
            # critic iterations and by the generator loss below.
            D_loss_obs.backward(retain_graph=True)
            d_optimizer.step()

        # Generator loss for this intervention, scored by the freshly updated critic.
        D_fake_decision_obs = discriminator(generated_labels_fill).squeeze()
        G_loss += -torch.mean(D_fake_decision_obs)

    # generated_labels_fill is not detached, so the critic backward passes may also
    # have left gradients in the generators. Clear them so only G_loss drives this step.
    for mech in cur_mechs:
        label_generators[mech].zero_grad()
    G_loss.backward()
    for mech in cur_mechs:
        G_optimizers[mech].step()

    # .item() works regardless of the device the loss tensors live on.
    D_loss = torch.tensor([loss.item() for loss in D_losses]).mean()
    return G_loss.data, D_loss.data


def labelMain(Exp, cur_mechs, label_generators, G_optimizers, discriminators, D_optimizers, dataset_dict,
              tvd_diff, kl_diff):
    """Train the mechanisms in ``cur_mechs`` for one epoch (``Exp.curr_epoochs``).

    Each dataset in ``dataset_dict`` is split into fixed-order (unshuffled) batches
    of ``Exp.batch_size`` rows. Batch ``i`` of every dataset is trained together by
    ``train_CausalController``. During the epoch:

    - the temperature is annealed every 100 global iterations;
    - ``evaluate_after_epochs`` runs roughly ``Exp.PLOTS_PER_EPOCH`` times;
    - after the epoch, a checkpoint is saved every 5th epoch.

    Args:
        Exp: experiment config/state; loss histories are appended to
            ``Exp.D_avg_losses`` / ``Exp.G_avg_losses``.
        cur_mechs: names of the mechanisms being trained.
        label_generators, G_optimizers: per-mechanism generators and optimizers.
        discriminators, D_optimizers: per-intervention critics and optimizers.
        dataset_dict: intervention key -> 2-D data tensor (rows = samples).
        tvd_diff, kl_diff: metric histories forwarded to ``evaluate_after_epochs``.
            They are rebound locally from its return value, so callers only see
            in-place mutations - NOTE(review): confirm the helper mutates them.

    Returns:
        Always ``100`` (placeholder; per-query TVD is not returned).
    """
    dataset_dict_batches = {}
    for key, each_dataset in dataset_dict.items():
        real_dataloader = torch.utils.data.DataLoader(dataset=each_dataset,
                                                      batch_size=Exp.batch_size,
                                                      shuffle=False)
        # Keep the batch axis: a bare squeeze() also dropped it when the final batch
        # held a single row, which broke the data_input[:, cols] indexing downstream.
        dataset_dict_batches[key] = [batch.reshape(batch.size(0), -1) for batch in real_dataloader]

    # Batch i is trained jointly across datasets, so only indices present in all of
    # them are usable (the last dataset's count alone could cause IndexError).
    num_batches = min((len(batches) for batches in dataset_dict_batches.values()), default=0)
    # At least 1, so fewer batches than PLOTS_PER_EPOCH cannot trigger a modulo by zero.
    eval_every = max(1, int(num_batches / Exp.PLOTS_PER_EPOCH))

    for iteration in range(num_batches):
        g_loss, d_loss = train_CausalController(Exp, cur_mechs, label_generators, G_optimizers, discriminators,
                                                D_optimizers, dataset_dict_batches, iteration)

        print('Epoch [%d/%d], Step [%d/%d],' % (
            Exp.curr_epoochs + 1, Exp.num_epochs, iteration + 1, num_batches),
              'mechanism: ', cur_mechs, ' D_loss: %.4f, G_loss: %.4f' % (d_loss.data, g_loss.data))

        # Annealing
        tot_iter = Exp.curr_epoochs * num_batches + iteration
        if tot_iter % 100 == 0:
            Exp.anneal_temperature(tot_iter)

        if (iteration + 1) % eval_every == 0:
            tvd_diff, kl_diff = evaluate_after_epochs(Exp, cur_mechs, label_generators, dataset_dict,
                                                      tvd_diff, kl_diff)

        Exp.D_avg_losses.append(torch.mean(d_loss))
        Exp.G_avg_losses.append(torch.mean(g_loss))

    if (Exp.curr_epoochs + 1) % 5 == 0:
        var_list = "".join(cur_mechs)
        save_checkpoint(Exp, Exp.SAVED_PATH, cur_mechs, label_generators, G_optimizers,
                        {var_list: discriminators}, {var_list: D_optimizers})
        print(Exp.curr_epoochs, ":model saved at ", Exp.SAVED_PATH)

    return 100






if __name__ == "__main__":

    # Experiment configuration (temperature schedule, hidden sizes, critic iterations,
    # gradient-penalty weight, sample sizes, ...).
    Exp = Experiment("Exp1", set_trainGraph,
                     dist_thresh=0.15,
                     causal_hierarchy=2,
                     Temperature=1,
                     temp_min=0.1,
                     G_hid_dims=[256, 256],
                     D_hid_dims=[256, 256],
                     IMAGE_FILTERS=[128, 64, 32],
                     CRITIC_ITERATIONS=5,
                     LAMBDA_GP=1,
                     learning_rate=2 * 1e-4,
                     Synthetic_Sample_Size=10000,
                     intv_Sample_Size=10000,
                     batch_size=200,
                     features=["feature"],
                     noise_states=100,
                     latent_state=16,
                     Data_intervs=[{}],
                     num_epochs=200,
                     new_experiment=True,
                     obs_state=3
                     )

    print(Exp.Data_intervs)
    Exp.intv_batch_size = Exp.batch_size

    os.makedirs(Exp.SAVED_PATH, exist_ok=True)

    # Mechanisms that get_generators / get_discriminators should restore from
    # LOAD_MODEL_PATH (presumably True = load). The path is a placeholder to fill in.
    Exp.LOAD_MODEL_PATH = "previous_model_path"
    Exp.load_which_models = {"X0": True, "X1": False, "X2": False,
                             "W0": True, "W1": True,
                             "Y0": True, "Y1": False}

    # c-components are trained one per run; comp_no selects the component trained now.
    c_components = [{"num_dataset": 1, "cur_mechs": ["W0"]},
                    {"num_dataset": 1, "cur_mechs": ["W1"]},
                    {"num_dataset": 1, "cur_mechs": ["X0", "Y0"]},
                    {"num_dataset": 1, "cur_mechs": ["X1", "X2", "Y1"]}]
    comp_no = 3

    # Mechanism groups whose previously saved TVD/KL histories should be reloaded.
    prev_mechs = [comp["cur_mechs"] for comp in c_components[:comp_no]]
    prev_mechs.append(["X0", "W0", "W1", "Y0"])
    prev_mechs.append(Exp.label_names)

    each_com = c_components[comp_no]
    cur_mechs = each_com["cur_mechs"]
    Exp.num_datasets = each_com["num_dataset"]

    # Extend the journey log inherited from the loaded model (if any) with this run,
    # and write it into the new save directory.
    journey_file = os.path.join(Exp.LOAD_MODEL_PATH, "model_journeys.txt")
    journey_dict = {"journeys": []}
    if os.path.exists(journey_file):
        with open(journey_file) as f:
            journey_dict = json.load(f)

    mech_str = "".join(cur_mechs)
    journey_dict["journeys"].append({"cur_mech": mech_str, "path": Exp.SAVED_PATH})
    with open(os.path.join(Exp.SAVED_PATH, "model_journeys.txt"), 'w') as fp:
        json.dump(journey_dict, fp)

    label_generators, optimizersMech = get_generators(Exp, Exp.load_which_models)
    discriminatorsMech, doptimizersMech = get_discriminators(Exp, cur_mechs, Exp.load_which_models)

    # Label datasets (image labels excluded), keyed by their intervention.
    dataset_dict = {}
    for dno in range(Exp.num_datasets):
        columns = [get_dataset(Exp, label, dno) for label in Exp.label_names if label not in Exp.image_labels]
        dataset_dict[asKey(Exp.Data_intervs[dno])] = torch.cat(columns, 1).to(Exp.DEVICE)

    tvd_diff = {}
    kl_diff = {}

    # Empty histories for the full observational joint ...
    obs_query = getdoKey(Exp.label_names, {})
    tvd_diff[obs_query] = []
    kl_diff[obs_query] = []

    # ... and for the observational joint of the variables the current mechanisms compare.
    compare_Var = []
    for mech in cur_mechs:
        compare_Var += [lb for lb in Exp.train_mech_dict[mech][0]["compare"] if lb not in compare_Var]
    obs_query = getdoKey(compare_Var, {})
    tvd_diff[obs_query] = []
    kl_diff[obs_query] = []

    if any(Exp.load_which_models.values()):
        print("loading previous tvd diffs")
        for p_mechs in prev_mechs:
            dist = getdoKey(p_mechs, {})
            tvd_path = os.path.join(Exp.LOAD_MODEL_PATH, "tvd", dist)
            kl_path = os.path.join(Exp.LOAD_MODEL_PATH, "kl", dist)
            # Require both histories: previously a missing KL file crashed after the TVD load.
            if os.path.exists(tvd_path) and os.path.exists(kl_path):
                tvd_diff[dist] = torch.load(tvd_path).tolist()
                kl_diff[dist] = torch.load(kl_path).tolist()
            elif os.path.exists(tvd_path):
                print("skipping", dist, ": tvd history found but kl history missing")

    print("Starting training new mechanism")

    for epoch in range(Exp.num_epochs):
        Exp.curr_epoochs = epoch
        mech_tvd = labelMain(Exp, cur_mechs, label_generators, optimizersMech, discriminatorsMech, doptimizersMech,
                             dataset_dict, tvd_diff, kl_diff)
